package ffte.test

import spinal.core._
import spinal.core.sim._

import scala.util.Random._

import ffte.evaluate.cases.{Sim,FFTCase}
import ffte.evaluate.vectors.Cplx

import ffte.property.{getShiftMethod,FFTIOConfig,getFFTIW,getFFTOW}
import ffte.core.{StreamedGroupConfig}

import ffte.core.mixed.{MFFT,MIFFT}
import ffte.core.mixed
import ffte.algorithm.FFTGen

/**
 * Mixed-Radix FFT Hardware Simulation Test
 *
 * This test program validates the correctness of mixed-radix FFT hardware implementations.
 * Mixed-radix FFT combines different decomposition strategies (prime factorization,
 * power-of-2 decomposition) for optimal performance across various FFT sizes.
 *
 * Features:
 * - Mixed-radix decomposition with configurable grouping
 * - Streaming interface with ready/valid handshaking
 * - Frame-based test vector validation (the first two output frames are not checked)
 *
 * Architecture:
 * - `g` samples are presented per cycle, `n` cycles form one frame (FFT size N = g*n)
 * - Uses mixed-radix algorithms for composite sizes
 * - Output is compared against reference vectors after the DUT's reported `delay`
 *
 * Run modes (selected by the number of command line arguments):
 * - no arguments: `caseTest` for every key of `FFTCase.prime` (sorted, first two skipped)
 * - one argument: `caseTest(g, FFTCase.cases)`
 * - three to five arguments: `singleTest(g, n, k)`
 *
 * Command Line Arguments:
 * args(0): Group size g (default 2, must be > 0)
 * args(1): FFT size per group n (required for single test)
 * args(2): Value assigned to `K` (required for single test; presumably the
 *          scaling setting of the Sim framework — confirm against `Sim`)
 * args(3): Input bit width (optional, default `getFFTIW()`)
 * args(4): Output bit width (optional, default `getFFTOW()`)
 *
 * @author FFTE Project
 * @since 2021
 *
 * @example
 * // Test mixed-radix FFT with group size 4, 64 points per group
 * runMain ffte.test.MixedStudy 4 64 3
 *
 * @see [[ffte.core.mixed.MFFT]] for mixed-radix FFT implementation
 * @see [[ffte.evaluate.cases.Sim]] for simulation framework
 */
object MixedStudy extends Sim {

    /** Input data bit width configuration */
    var dW   = getFFTIW()

    /** Output data bit width configuration */
    var oW   = getFFTOW()

    /**
     * Always return true for valid signal (continuous testing mode).
     *
     * @return true for continuous data flow
     */
    def valid() = true

    /**
     * Always return true for ready signal (continuous testing mode).
     *
     * @return true for continuous data reception
     */
    def ready() = true

    /**
     * Mixed-radix FFT simulation function.
     *
     * Compiles `MFFT(g,n)`, drives `g` input samples per cycle (reordered by the
     * DUT's `in_tab`), and compares every output sample (reordered by `out_tab`)
     * against the reference vectors once the DUT's `delay` has elapsed. A sample
     * fails when its L1 distance to the reference exceeds 32. Output frames 0 and
     * 1 are never checked.
     *
     * @param g Group size (samples per cycle)
     * @param n FFT size per group (cycles per frame)
     * @return (failcnt < (C-delay-n+1)/n, number of output frames with at least one failing sample)
     */
    def sim(g:Int,n:Int) : (Boolean,Int) = {
        N = n*g  // Total FFT size = group size × size per group
        C = 100*N
        // Extra shift when the output is narrower than the input.
        S   = getShiftMethod(N) + (if(dW>oW) (dW-oW) else 0)
        
        var delay = 0
        val tv = test_vectors
        println(s"simulation $N gen test vectors ${getFFTIW()}")
        var ridx = 0      // sample index within the current output frame
        var fok  = true   // current output frame still passing
        
        SimConfig.compile {
            val fft = MFFT(g,n).asInstanceOf[mixed.dft]
            // Captured at elaboration time; used to align expected outputs.
            delay = fft.delay
            fft
        }.doSim{ dut =>
            val tab     = dut.in_tab
            val out_tab = dut.out_tab 
            done({
                // Setup: reset counters, start the clock, pulse `last` for one cycle.
                idx = 0
                ridx = 0
                failcnt = 0
                dut.clockDomain.forkStimulus(period = 10)
                dut.io.q.ready        #= true
                dut.io.d.payload.last #= true
                dut.clockDomain.waitRisingEdge()
                dut.io.d.payload.last #= false
            },{
                val frame  = idx/n
                val iidx   = idx%n
                val tframe = frame%CC
                // NOTE(review): integer division truncates toward zero, so for
                // delay-n < idx < delay this yields 0 rather than a negative frame;
                // frames 0 and 1 are skipped below, which masks this — confirm intent.
                val rframe = (idx-delay)/n
                val d : Array[Cplx] = tv(tframe).d
                for(i <- 0 until g) {
                    dut.io.d.payload.fragment.d(i).re.d.raw #= d(tab(iidx)(i)).re
                    dut.io.d.payload.fragment.d(i).im.d.raw #= d(tab(iidx)(i)).im
                }
                dut.io.d.valid            #= valid()
                dut.io.q.ready            #= ready()
                // Mark the final cycle of each input frame.
                dut.io.d.payload.last     #= (iidx==n-1)
                dut.clockDomain.waitRisingEdge()
                if(dut.io.d.ready.toBoolean && dut.io.d.valid.toBoolean) {
                    next
                    if(rframe>1) {
                        for(i <- 0 until g) {
                            val q = tv(rframe%CC).q(out_tab(ridx%n)(i))
                            val r = Cplx(dut.io.q.payload.fragment.d(i).re.d.raw.toInt,dut.io.q.payload.fragment.d(i).im.d.raw.toInt)
                            if((r-q).norm1>32) {
                                if(debug>2) {
                                    println(s"$rframe($ridx,$i) $r != $q")
                                }
                                fok = false
                                ok  = false
                            }
                        }
                    }
                    if(rframe>=0) {
                        // Output `last` re-synchronises the output sample index.
                        if(dut.io.q.payload.last.toBoolean) {
                            ridx = 0
                            if(!fok) failcnt += 1
                            fok = true
                        } else {
                            ridx += 1
                        }
                    }
                }
            })
        }
        (failcnt<(C-delay-n+1)/n,failcnt)
    }

    /**
     * Run one configuration with verbose output under the configured bit widths.
     *
     * @param g Group size
     * @param n FFT size per group
     * @param k Value assigned to `K` before simulation
     */
    def singleTest(g:Int,n:Int,k:Int):Unit = {
        debug = 3
        K = k
        FFTGen.Winograd
        println(s"$n,$g($inv):single test")
        val (_ook,failcnt) = FFTIOConfig(dW,oW) on sim(g,n)
        if(_ook) println(s"$n,$g($inv): PASS($failcnt)") else println(s"$n: FAIL")
        println(s"shift:$S")
    }

    /**
     * Run `sim(g, n)` for every key `n` of `cases`, in ascending order.
     * Each result is appended to `tmp/mixed.log`; a summary is printed at the end.
     * Only the keys of `cases` are used.
     *
     * @param g     Group size
     * @param cases Map whose keys are the per-group FFT sizes to test
     */
    def caseTest(g:Int,cases:Map[Int,Int]): Unit = {
        K = 0
        debug = 1
        val results = cases.keys.toList.sorted.map { n =>
            println(s"start $g,$n test")
            val (_ook,failcnt) = sim(g,n)
            val msg = if(_ook) s"$n,$g($inv): PASS($failcnt)" else s"$n: FAIL"
            Sim.log2file("tmp/mixed.log",msg)
            (n,_ook,failcnt)
        }
        results.foreach { case (n,passed,failcnt) =>
            if(passed) println(s"$n($inv): PASS($failcnt)") else println(s"$n: FAIL")
        }
    }

    /**
     * Entry point; see the object documentation for the argument layout.
     * Rejects non-positive group sizes and a single-test invocation missing `k`.
     */
    def main(args: Array[String]): Unit = {
        ffte.default.core.GlobeSetting.evaluateNative = false
        val g = if(args.length>0) args(0).toInt else 2
        if(g<=0) {
            println("Group size g must be greater than 0")
        } else if(args.length==2) {
            // args(2) is required for a single test; previously this threw
            // ArrayIndexOutOfBoundsException.
            println("Usage: MixedStudy <g> <n> <k> [dW] [oW]")
        } else if(args.length>2) {
            val n = args(1).toInt
            val k = args(2).toInt
            if(args.length>3) dW = args(3).toInt
            if(args.length>4) oW = args(4).toInt
            singleTest(g,n,k)
        } else if(args.length==1) {
            caseTest(g,FFTCase.cases)
        } else {
            for(pg <- FFTCase.prime.keys.toList.sorted.drop(2)) caseTest(pg,FFTCase.cases)
        }
    }
}

object MixedGen {
    /**
     * Emit Verilog for `MFFT(g, n)` via `SpinalVerilog`.
     *
     * args(0): group size g (must be > 0)
     * args(1): FFT size per group n
     *
     * Prints a usage message instead of throwing when arguments are missing
     * or the group size is not positive.
     */
    def main(args: Array[String]): Unit = {
        if(args.length<2) {
            println("Usage: MixedGen <g> <n>")
        } else {
            val g = args(0).toInt
            val n = args(1).toInt
            if(g<=0) {
                println("Group size g must be greater than 0")
            } else {
                // Presumably selects the Winograd generator before elaboration — confirm in FFTGen.
                FFTGen.Winograd
                SpinalVerilog(MFFT(g,n))
            }
        }
    }
}